# Alexander F. Gazmararian
# agazmararian@gmail.com

library(tidyverse)
library(tidylog)
library(here)
library(jsonlite)
library(httr)
library(openai)
library(progress)

# Run switch: passed as `reset` to annotate_statements() further down.
# TRUE discards today's intermediate results and starts annotation afresh;
# FALSE resumes from cached work where possible.
new_annotation <- FALSE

# Project helpers. NOTE(review): log_message() and check_api_key() are used
# below but not defined in this file -- presumably they come from here; confirm.
source(here("R", "load_functions.R"))

# Load the most recent archived annotation CSV from the annotation cache.
#
# The cache is organised as <cache_base>/<YYYYMMDD>/<filename>. Date folders
# are scanned from newest to oldest and the first one that actually contains
# the file is used, so an empty or partially written newest folder (for
# example one left behind by an interrupted run) does not hide older,
# completed annotations.
#
# Arguments:
#   filename          File name to look for inside each date folder. Ignored
#                     when `robustness_check` is given (the name is then
#                     derived from the identifier).
#   robustness_check  Optional identifier; when non-NULL the robustness
#                     sub-cache and the matching file name are used.
#
# Returns the data read by read_csv(), or NULL when no cached file exists.
load_most_recent_annotations <- function(filename = "annotated_statements.csv", robustness_check = NULL) {
  if (is.null(robustness_check)) {
    cache_base <- here("data", "cache", "annotations")
  } else {
    # Robustness checks live in their own sub-tree with their own file name
    filename <- sprintf("annotated_statements_%s.csv", robustness_check)
    cache_base <- here("data", "cache", "annotations", "robustness", robustness_check)
  }

  if (!dir.exists(cache_base)) {
    return(NULL)
  }

  # Keep only folders named like YYYYMMDD; for that fixed-width format the
  # lexical order is also the chronological order.
  date_folders <- list.dirs(cache_base, full.names = FALSE, recursive = FALSE)
  date_folders <- date_folders[grepl("^\\d{8}$", date_folders)]

  for (folder in sort(date_folders, decreasing = TRUE)) {
    file_path <- file.path(cache_base, folder, filename)
    if (!file.exists(file_path)) {
      next
    }
    if (is.null(robustness_check)) {
      message(sprintf("Loading annotations from %s", file_path))
    } else {
      message(sprintf("Loading robustness check annotations (%s) from %s", robustness_check, file_path))
    }
    return(read_csv(file_path, show_col_types = FALSE))
  }

  NULL
}

# Call the OpenAI chat completion endpoint with retries and error handling.
#
# Arguments:
#   model        Model name passed to create_chat_completion().
#   messages     List of role/content message lists.
#   temperature  Sampling temperature passed through to the API.
#   max_retries  Maximum number of attempts; 0 means no call is made at all.
#   sleep_base   Base of the exponential backoff (sleep_base^attempt seconds)
#                applied after rate-limit errors.
#
# Returns the content of the response message, or NA_character_ when every
# attempt failed or came back empty.
safe_chat_completion <- function(model, messages, temperature = 0, max_retries = 3, sleep_base = 2) {
  # seq_len() rather than 1:max_retries so that max_retries = 0 runs no attempts
  for (attempt in seq_len(max_retries)) {
    is_last_attempt <- attempt >= max_retries

    result <- tryCatch({
      response <- create_chat_completion(
        model = model,
        messages = messages,
        temperature = temperature
      )
      if (is.null(response) || is.null(response$choices) || length(response$choices) == 0) {
        NULL
      } else {
        content <- response$choices$message.content
        # A zero-length content vector is treated like a missing response so
        # callers can safely test the result with is.na()
        if (length(content) == 0) NULL else content
      }
    }, error = function(e) {
      err_msg <- conditionMessage(e)
      message(sprintf("[OpenAI API] Error on attempt %d: %s", attempt, err_msg))
      # Waiting only helps when another attempt follows
      if (!is_last_attempt) {
        if (grepl("rate limit|429", err_msg, ignore.case = TRUE)) {
          Sys.sleep(sleep_base^attempt) # Exponential backoff for rate limits
        } else {
          Sys.sleep(1)
        }
      }
      NULL
    })

    if (!is.null(result)) {
      return(result)
    }
  }

  message("[OpenAI API] All retries failed. Returning NA.")
  NA_character_
}

# Names of the ten codes the stage-2 model must return, in output column order.
.annotation_fields <- c(
  "gives_credit", "credit_biden", "credit_senate", "credit_us_rep",
  "credit_governor", "credit_local", "credit_dem", "credit_gop",
  "credit_ira", "credit_bil"
)

# TRUE where `txt` holds text worth sending to the API (not NA, "" or "NA").
# Vectorised; FALSE & NA is FALSE, so NA input yields FALSE rather than NA.
.has_statement_text <- function(txt) {
  !is.na(txt) & txt != "" & txt != "NA"
}

# Coerce one parsed or cached value to a single integer. NULL and empty values
# become NA and only the first element of longer values is kept, so that
# data.frame() always receives length-1 columns.
.as_scalar_int <- function(x) {
  x <- unlist(x, use.names = FALSE)
  if (length(x) == 0) {
    return(NA_integer_)
  }
  as.integer(x[1])
}

# Build the one-row result for a statement. `codes` is any list-like object
# holding the annotation fields (parsed JSON or a cached row); NULL, or a
# missing field, gives NA for that code.
.result_row <- function(txt, id, codes = NULL) {
  values <- lapply(.annotation_fields, function(field) .as_scalar_int(codes[[field]]))
  names(values) <- .annotation_fields
  do.call(data.frame, c(list(statement = txt), values, list(id = id)))
}

# "label: value" metadata line for row `i`, or NULL when the column is not
# present in `df` or the value is NA (NULL entries vanish inside c()).
.meta_line <- function(df, col, i, label) {
  if (is.null(col) || !(col %in% names(df))) {
    return(NULL)
  }
  value <- df[[col]][i]
  if (is.na(value)) NULL else paste0(label, ": ", value)
}

# Parse the stage-2 response for row `row`. Returns the parsed list when it is
# valid JSON containing every expected field; otherwise logs why and returns
# NULL.
.parse_annotation_json <- function(response, row) {
  if (is.na(response)) {
    return(NULL)
  }
  tryCatch({
    parsed_json <- jsonlite::fromJSON(response)
    missing_fields <- setdiff(.annotation_fields, names(parsed_json))
    if (!is.list(parsed_json)) {
      log_message("ERROR", sprintf("Invalid JSON structure for row %d", row))
      NULL
    } else if (length(missing_fields) > 0) {
      log_message("ERROR", sprintf("Missing fields in JSON for row %d: %s",
                                   row, paste(missing_fields, collapse = ", ")))
      NULL
    } else {
      parsed_json
    }
  }, error = function(e) {
    log_message("ERROR", sprintf("JSON parsing failed for row %d: %s", row, conditionMessage(e)))
    NULL
  })
}

#' Annotate Statements with OpenAI GPT
#'
#' Annotates a data frame of statements using a two-stage GPT-based
#' classification and coding process. Stage 1 asks two YES/NO probes (one for
#' the Inflation Reduction Act, one for the Bipartisan Infrastructure Law);
#' stage 2 sends the statement, its metadata and the probe answers to the main
#' model, which must reply with a JSON object holding the ten credit codes.
#'
#' Work is resumable. Unless `reset` is TRUE, results are taken first from
#' today's intermediate file and otherwise from the most recent archived
#' annotations (matched by statement ID). Every row that has text but no
#' result yet is then processed, wherever it sits in `df_in`. Rows whose API
#' call or JSON parsing failed are stored with NA codes and count as done.
#'
#' @param df_in Data frame. The input data containing statements to annotate.
#' @param input_text Character. Name of the column in `df_in` containing the statement text. Default: "statement".
#' @param id_col Character. Name of the column in `df_in` containing unique statement IDs. Default: "statement_id".
#' @param speaker Character. Name of the column in `df_in` with the speaker's name. Default: "speaker_name".
#' @param role_col Character. Name of the column in `df_in` with the speaker's role. Default: "speaker_role".
#' @param release_type Character. Name of the column in `df_in` with the type of press release or statement. Optional: left out of the prompt when the column is absent. Default: "release_type".
#' @param state_col Character. Name of the column in `df_in` with the state information. Optional, as above. Default: "state".
#' @param district_col Character. Name of the column in `df_in` with the district information. Optional, as above. Default: "district".
#' @param city_col Character. Name of the column in `df_in` with the city information. Optional, as above. Default: "city".
#' @param model_stage1 Character. OpenAI model name for the first (binary classification) stage. Default: "gpt-3.5-turbo-0125".
#' @param model_stage2 Character. OpenAI model name for the second (main coding) stage. Default: "gpt-4o-mini".
#' @param temp Numeric. Temperature for the stage-2 call, between 0 and 1 (stage-1 probes always use 0). Default: 0.
#' @param sleep_sec Numeric. Seconds to sleep after each processed row (to avoid rate limits). Default: 1.
#' @param reset Logical. If TRUE, deletes any existing intermediate results and starts fresh. Default: FALSE.
#' @param save_every Integer. Save intermediate results to disk after every N processed rows (default: 10). Helps improve performance for large datasets while protecting against data loss.
#' @param codebook_path Character. Path to the codebook file to use for annotations. Default: NULL, which resolves to "R/annotation/codebook.md" under the project root.
#' @param robustness_check Character. Optional identifier for robustness checks (e.g., "alt_codebook"). If provided, outputs will be saved with this identifier to avoid confusion with primary annotations. Default: NULL.
#' @param probe_template_path Character. Path to the stage-1 probe template file. Default: NULL, which resolves to "R/annotation/probe_template.md" under the project root.
#'
#' @return A list containing:
#'   - `annotations`: A data frame with one row per annotated statement
#'     (`statement`, the ten numeric credit codes, and `id`)
#'   - `new_work_done`: Logical indicating whether new API calls were made
annotate_statements <- function(
    df_in,
    input_text = "statement",
    id_col = "statement_id",
    speaker = "speaker_name",
    role_col = "speaker_role",
    release_type = "release_type",
    state_col = "state",
    district_col = "district",
    city_col = "city",
    model_stage1 = "gpt-3.5-turbo-0125",
    model_stage2 = "gpt-4o-mini",
    temp = 0,
    sleep_sec = 1,
    reset = FALSE,
    save_every = 10,
    codebook_path = NULL,
    robustness_check = NULL,
    probe_template_path = NULL) {
  # -------- input validation --------
  if (!is.data.frame(df_in)) {
    stop("df_in must be a data frame")
  }

  required_cols <- c(input_text, id_col, speaker, role_col)
  missing_cols <- setdiff(required_cols, names(df_in))
  if (length(missing_cols) > 0) {
    stop("Missing required columns: ", paste(missing_cols, collapse = ", "))
  }

  # Scalar check first so the comparisons below never see NA or a vector
  is_number <- function(x) is.numeric(x) && length(x) == 1 && !is.na(x)
  if (!is_number(temp) || temp < 0 || temp > 1) {
    stop("temp must be a number between 0 and 1")
  }
  if (!is_number(sleep_sec) || sleep_sec < 0) {
    stop("sleep_sec must be a non-negative number")
  }
  if (!is_number(save_every) || save_every < 1) {
    stop("save_every must be a positive integer")
  }

  n_rows <- nrow(df_in)
  log_message("INFO", "Starting annotation process")
  log_message("INFO", sprintf("Processing %d statements", n_rows))

  # -------- prompts --------
  if (is.null(codebook_path)) {
    codebook_path <- here("R", "annotation", "codebook.md")
  }
  if (!file.exists(codebook_path)) {
    stop("Codebook not found at: ", codebook_path)
  }

  if (is.null(probe_template_path)) {
    probe_template_path <- here("R", "annotation", "probe_template.md")
  }
  if (!file.exists(probe_template_path)) {
    stop("Probe template not found at: ", probe_template_path)
  }

  if (!is.null(robustness_check)) {
    log_message("INFO", sprintf("Running robustness check: %s", robustness_check))
    log_message("INFO", sprintf("Using alternative codebook: %s", codebook_path))
    log_message("INFO", sprintf("Using alternative probe template: %s", probe_template_path))
  } else {
    log_message("INFO", "Running primary annotation")
    log_message("INFO", sprintf("Using codebook: %s", codebook_path))
    log_message("INFO", sprintf("Using probe template: %s", probe_template_path))
  }

  # The codebook is the stage-2 system prompt
  system_prompt <- read_file(codebook_path) |>
    stringi::stri_enc_toutf8() |>
    trimws()

  probe_template <- read_file(probe_template_path) |>
    stringi::stri_enc_toutf8() |>
    trimws()

  # Stage-1 system prompt: the template with {LAW_NAME} filled in
  probe_sys <- function(law) {
    gsub("\\{LAW_NAME\\}", law, probe_template)
  }

  # Ask one stage-1 YES/NO question about `law` for the text of row `row`.
  # Returns "YES", "NO", or NA_character_ when the API call failed.
  run_probe <- function(law, tag, txt, row) {
    answer <- safe_chat_completion(
      model = model_stage1,
      messages = list(
        list(role = "system", content = probe_sys(law)),
        list(role = "user", content = txt)
      ),
      temperature = 0
    )
    if (is.na(answer)) {
      log_message("ERROR", sprintf("%s probe failed for row %d", tag, row))
      return(NA_character_)
    }
    if (grepl("^\\s*YES", answer, ignore.case = TRUE)) "YES" else "NO"
  }

  # -------- cache locations (one folder per run date) --------
  date_folder <- format(Sys.Date(), "%Y%m%d")
  cache_dir <- if (is.null(robustness_check)) {
    here("data", "cache", "annotations", date_folder)
  } else {
    here("data", "cache", "annotations", "robustness", robustness_check, date_folder)
  }
  save_path <- file.path(cache_dir, "intermediate_results.rds")

  # Write the current `results` list to today's intermediate file
  save_results <- function() {
    if (!dir.exists(cache_dir)) {
      dir.create(cache_dir, recursive = TRUE)
    }
    saveRDS(results, file = save_path)
  }

  # Stack the per-row results into one data frame with numeric codes. With no
  # results at all, a zero-row frame with the same columns is returned so the
  # column range used below always exists.
  collect_results <- function(res) {
    done <- Filter(Negate(is.null), res)
    log_message("INFO", sprintf("Found %d non-null results", length(done)))
    stacked <- if (length(done) == 0) {
      template <- .result_row(NA_character_, df_in[[id_col]][NA_integer_])
      template[0, , drop = FALSE]
    } else {
      bind_rows(lapply(done, as.data.frame))
    }
    stacked %>%
      mutate(across(gives_credit:credit_bil, as.numeric))
  }

  # -------- reset / resume --------
  if (reset) {
    log_message("INFO", "Reset requested - starting fresh")
    if (!dir.exists(cache_dir)) {
      dir.create(cache_dir, recursive = TRUE)
    }
    if (file.exists(save_path)) {
      log_message("INFO", "Removing existing intermediate results")
      unlink(save_path)
    }
    results <- vector("list", n_rows)
  } else if (file.exists(save_path)) {
    # First priority: intermediate results from today
    log_message("INFO", "Found existing intermediate results - attempting to resume")
    results <- readRDS(save_path)
  } else {
    # Second priority: completed annotations from previous dates
    log_message("INFO", "No intermediate results found for today - checking for previous completed annotations")
    previous_annotations <- tryCatch({
      load_most_recent_annotations(robustness_check = robustness_check)
    }, error = function(e) {
      log_message("INFO", sprintf("No previous annotations found: %s", conditionMessage(e)))
      NULL
    })

    results <- vector("list", n_rows)
    if (is.null(previous_annotations)) {
      log_message("INFO", "No existing results found - starting new annotation process")
    } else {
      log_message("INFO", sprintf("Found %d previous annotations - will resume from where they left off", nrow(previous_annotations)))

      # Handle both old cache files (with 'id') and new ones (with 'statement_id')
      id_column_name <- if ("id" %in% names(previous_annotations)) "id" else "statement_id"

      # match() gives, for every input row, the first previous row with the
      # same ID (or NA); missing IDs are never treated as a match.
      current_ids <- df_in[[id_col]]
      hits <- match(current_ids, previous_annotations[[id_column_name]])
      hits[is.na(current_ids)] <- NA_integer_

      for (i in which(!is.na(hits))) {
        results[[i]] <- .result_row(
          df_in[[input_text]][i],
          current_ids[i],
          previous_annotations[hits[i], , drop = FALSE]
        )
      }
      if (all(is.na(hits))) {
        log_message("INFO", "No matching previous annotations found - starting from beginning")
      }
    }
  }

  # A row is pending when it has text and no result yet. Working from this set
  # (rather than from "last completed row + 1") means rows missing from the
  # cache in the middle of the input are not silently dropped.
  has_result <- !vapply(results[seq_len(n_rows)], is.null, logical(1))
  pending <- which(!has_result & .has_statement_text(df_in[[input_text]]))

  if (any(has_result)) {
    log_message("INFO", sprintf("Resuming: %d of %d rows already have results", sum(has_result), n_rows))
  }

  if (length(pending) == 0) {
    log_message("INFO", "No rows need API calls - all rows are already processed or NA/empty")
    return(list(annotations = collect_results(results), new_work_done = FALSE))
  }

  # Only reached when at least one API call is needed
  check_api_key()

  log_message("INFO", sprintf("Will process %d rows (first: %d, last: %d)",
                              length(pending), min(pending), max(pending)))

  pb <- progress_bar$new(
    format = "Processing [:bar] :percent | ETA: :eta | :current/:total",
    total = length(pending)
  )

  n_processed <- 0L
  for (i in pending) {
    log_message("INFO", sprintf("Starting processing of row %d", i))
    txt <- df_in[[input_text]][i]

    # --- Stage-1 probes
    log_message("INFO", sprintf("Running stage-1 probes for row %d", i))
    ira_flag <- run_probe("the Inflation Reduction Act", "IRA", txt, i)
    bil_flag <- run_probe("the Bipartisan Infrastructure Law", "BIL", txt, i)

    # ----- build meta block -----
    meta <- c(
      .meta_line(df_in, speaker, i, "speaker"),
      .meta_line(df_in, role_col, i, "role"),
      .meta_line(df_in, state_col, i, "state"),
      .meta_line(df_in, district_col, i, "district"),
      .meta_line(df_in, city_col, i, "city"),
      .meta_line(df_in, release_type, i, "release_type"),
      paste0("ira_funding: ", ira_flag),
      paste0("bil_funding: ", bil_flag)
    )

    prompt <- paste(paste(meta, collapse = "\n"), "###", txt, sep = "\n")

    # ----- Stage-2 main model -----
    response <- safe_chat_completion(
      model = model_stage2,
      messages = list(
        list(role = "system", content = system_prompt),
        list(role = "user", content = prompt)
      ),
      temperature = temp
    )

    # A failed call or unusable JSON gives `parsed = NULL`, which is stored as
    # a row of NA codes.
    parsed <- .parse_annotation_json(response, i)
    results[[i]] <- .result_row(txt, df_in[[id_col]][i], parsed)

    # Save on the count of rows processed in this run, so the cadence holds
    # even when the pending rows are not contiguous.
    n_processed <- n_processed + 1L
    if ((n_processed %% save_every) == 0) {
      log_message("INFO", sprintf("Saving intermediate results at row %d", i))
      save_results()
    }

    Sys.sleep(sleep_sec)
    log_message("INFO", sprintf("Completed processing row %d", i))
    pb$tick()
  }

  log_message("INFO", sprintf("Completed processing. Processed %d rows", n_processed))

  # `pending` was non-empty, so API calls were made and the results are saved
  log_message("INFO", "Saving final results (API calls were made)")
  save_results()

  log_message("INFO", "Processing intermediate results")
  processed_results <- collect_results(results)
  log_message("INFO", sprintf("Final processed results contain %d rows", nrow(processed_results)))

  log_message("INFO", sprintf("Annotation complete. Processed %d statements", nrow(processed_results)))
  list(annotations = processed_results, new_work_done = TRUE)
}


# Configuration for robustness checks
# Set to NULL for primary annotation, or specify identifier for robustness check
ROBUSTNESS_CHECK <- NULL  # Change this to run robustness checks (e.g., "strict_codebook")
CUSTOM_CODEBOOK <- NULL   # Change this to use alternative codebook path
CUSTOM_PROBE_TEMPLATE <- NULL  # Change this to use alternative probe template

# Load processed data
log_message("INFO", "Loading processed statements")
processed <- read_csv(here("data", "inter", "statements_processed.csv"), show_col_types = FALSE)

annotation_result <- annotate_statements(
  df_in = processed,
  reset = new_annotation,
  sleep_sec = 1,
  codebook_path = CUSTOM_CODEBOOK,
  robustness_check = ROBUSTNESS_CHECK,
  probe_template_path = CUSTOM_PROBE_TEMPLATE
)
annotated <- annotation_result$annotations
new_work_done <- annotation_result$new_work_done

# Post-filter check: a code of 1 is kept only when the statement text itself
# names the credited actor or law; otherwise it is reset to 0. The replacement
# is a double (0, not 0L) because the code columns are numeric and if_else()
# in older dplyr releases refuses to mix integer and double.
annotated_post <- annotated %>%
  mutate(
    # Only allow credit to Democrats if the party is explicitly named
    credit_dem = if_else(
      credit_dem == 1 &
        !grepl("\\bDemocrat(s)?\\b|\\bDemocratic Party\\b",
          statement,
          ignore.case = TRUE
        ),
      0, credit_dem
    ),
    # Same for Republicans
    credit_gop = if_else(
      credit_gop == 1 &
        !grepl("\\bRepublican(s)?\\b|\\bRepublican Party\\b",
          statement,
          ignore.case = TRUE
        ),
      0, credit_gop
    ),
    # Only allow credit to IRA if the IRA is explicitly mentioned. The acronym
    # is matched case-sensitively as a whole word: a case-insensitive
    # substring search would also hit words such as "Miranda" or "spiral".
    credit_ira = if_else(
      credit_ira == 1 &
        !(grepl("\\bIRA\\b", statement) |
          grepl("Inflation Reduction Act", statement, ignore.case = TRUE)),
      0, credit_ira
    ),
    # Same for the Bipartisan Infrastructure Law; without the whole-word,
    # case-sensitive match "BIL" would be found in "bill" and "billion".
    credit_bil = if_else(
      credit_bil == 1 &
        !(grepl("\\b(BIL|IIJA)\\b", statement) |
          grepl("Bipartisan Infrastructure Law|Infrastructure Investment and Jobs Act",
            statement,
            ignore.case = TRUE
          )),
      0, credit_bil
    ),
    # Only allow credit to Biden if the president is explicitly named
    credit_biden = if_else(
      credit_biden == 1 &
        !grepl("(?<!Vice )President\\b|Biden\\b|White House\\b", statement, ignore.case = TRUE, perl = TRUE),
      0, credit_biden
    )
  )

# Final processing: drop the statement text and restore the ID column name
annotated_post <- annotated_post %>%
  dplyr::select(-statement) %>%
  dplyr::rename(statement_id = id)

# The same file name is used in data/inter and in the dated cache archive
output_filename <- if (is.null(ROBUSTNESS_CHECK)) {
  "annotated_statements.csv"
} else {
  sprintf("annotated_statements_%s.csv", ROBUSTNESS_CHECK)
}

# Save to data/inter for use by subsequent scripts
inter_path <- here("data", "inter", output_filename)
write_csv(annotated_post, inter_path)
message(sprintf("Saved %s", inter_path))

# Only save to api_cache if new API calls were made
if (new_work_done) {
  message("New annotations were created - saving to api_cache for archival")
  date_folder <- format(Sys.Date(), "%Y%m%d")

  # Same directory layout that annotate_statements() uses for its cache
  if (!is.null(ROBUSTNESS_CHECK)) {
    cache_dir <- here("data", "cache", "annotations", "robustness", ROBUSTNESS_CHECK, date_folder)
  } else {
    cache_dir <- here("data", "cache", "annotations", date_folder)
  }

  if (!dir.exists(cache_dir)) {
    dir.create(cache_dir, recursive = TRUE)
  }
  cache_path <- file.path(cache_dir, output_filename)
  write_csv(annotated_post, cache_path)
  message(sprintf("Saved %s", cache_path))
} else {
  message("No new annotations were created - skipping api_cache save")
}

message("[OK] Annotation complete")